/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file softmax_tiling.cpp
 * \brief Temporary-buffer size queries and tiling-parameter computation for LogSoftMax.
 */

#include "lib/activation/logsoftmax_tiling.h"

#include <set>

#include "lib/activation/logsoftmax_tilingdata.h"
#include "impl/api_check/host_apicheck.h"

namespace optiling {
// Invokes REGISTER_TILING_DATA_CLASS with the identifier LogSoftMaxTilingOpApi and the
// LogSoftMaxTiling class (the type filled in by AscendC::LogSoftMaxTilingFunc below).
// NOTE(review): the macro is defined outside this file; presumably it registers the
// tiling-data class under that name - confirm against the macro definition.
REGISTER_TILING_DATA_CLASS(LogSoftMaxTilingOpApi, LogSoftMaxTiling)
}
namespace AscendC {
// Block size in bytes: dividing it by the element size yields the number of elements per block,
// and it is the alignment passed to the last-axis alignment check.
constexpr uint32_t SOFTMAX_DEFAULT_BLK_SIZE = 32;
// Element size of half in bytes; also reused as the divisor when halving baseM.
constexpr uint32_t SOFTMAX_HALF_SIZE = 2;
// Element size of float in bytes; temporary-buffer element counts are scaled by this value.
constexpr uint32_t SOFTMAX_FLOAT_SIZE = 4;
// Number of float-sized elements in one block (32 / 4 = 8); used as the row granularity of baseM.
constexpr uint32_t BASIC_TILE_NUM = SOFTMAX_DEFAULT_BLK_SIZE / SOFTMAX_FLOAT_SIZE;
// Multiplied by SOFTMAX_BASICBLOCK_UNIT to form the upper bound (64 * 256) on baseM * srcK
// in AdjustToBasicBlockBaseM.
constexpr uint32_t SOFTMAX_BASICBLOCK_MIN_SIZE = 256;
// srcK must be a multiple of this value for the basic-block adjustment to apply; it is also
// added to the per-row element count when sizing the temporary buffer.
constexpr uint32_t SOFTMAX_BASICBLOCK_UNIT = 64;
// Marks a declaration as possibly unused so the compiler does not warn about it.
#define UNUSED __attribute__((unused))
// Element sizes (in bytes) accepted by the type-size parameter check.
static const std::set<uint32_t> SUPPORT_TYPESIZE = { SOFTMAX_HALF_SIZE, SOFTMAX_FLOAT_SIZE };
// API names passed as template arguments to the HighLevelApiCheck helpers.
static constexpr const char LOG_SOFTMAX_GET_MAX[] = "GetLogSoftMaxMaxTmpSize";
static constexpr const char LOG_SOFTMAX_GET_MIN[] = "GetLogSoftMaxMinTmpSize";
static constexpr const char LOG_SOFTMAX_TILING[] = "LogSoftMaxTilingFunc";

/*!
 * \brief Collapse an N-D shape into a 2-D {srcM, srcK} view split at the last axis.
 *
 * srcK is the last dimension and srcM is the total element count divided by srcK.
 * All values are computed in 32-bit unsigned arithmetic.
 *
 * \param srcShape shape of the source tensor.
 * \return {srcM, srcK}; an empty vector when the shape has no dimensions or when the
 *         last dimension is zero as a uint32_t (callers treat a vector with fewer than
 *         two entries as "no valid shape").
 */
inline std::vector<uint32_t> GetLastAxisShapeND(const ge::Shape srcShape)
{
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    if (shapeDims.empty()) {
        return {};
    }

    // Accumulate in unsigned 32-bit: same result as before whenever the old mixed
    // uint32_t/int64_t multiplication was well defined, but without signed overflow.
    uint32_t calculateSize = 1;
    for (const int64_t dim : shapeDims) {
        calculateSize *= static_cast<uint32_t>(dim);
    }

    const uint32_t srcK = static_cast<uint32_t>(shapeDims.back());
    if (srcK == 0) {
        // A zero-length last axis would make the division below divide by zero.
        return {};
    }
    const uint32_t srcM = calculateSize / srcK;
    return { srcM, srcK };
}

/*!
 * \brief Adjust the per-iteration row count so it suits the basic-block case.
 *
 * The adjustment only applies when baseM is larger than BASIC_TILE_NUM, srcM is a
 * multiple of BASIC_TILE_NUM and srcK is a multiple of SOFTMAX_BASICBLOCK_UNIT.
 * In that case baseM is rounded down to a multiple of BASIC_TILE_NUM, lowered until it
 * divides srcM, and then halved while baseM * srcK is at or above
 * SOFTMAX_BASICBLOCK_UNIT * SOFTMAX_BASICBLOCK_MIN_SIZE.
 *
 * \param baseM [in,out] candidate number of rows per iteration; never reduced below 1.
 * \param srcM  total number of rows.
 * \param srcK  length of the last axis.
 */
inline void AdjustToBasicBlockBaseM(uint32_t& baseM, const uint32_t srcM, const uint32_t srcK)
{
    const bool isBasicBlock =
        baseM > BASIC_TILE_NUM && srcM % BASIC_TILE_NUM == 0 && srcK % SOFTMAX_BASICBLOCK_UNIT == 0;
    if (!isBasicBlock) {
        return;
    }

    baseM = baseM / BASIC_TILE_NUM * BASIC_TILE_NUM;
    // Terminates at BASIC_TILE_NUM at the latest, because srcM is a multiple of it.
    while (srcM % baseM != 0) {
        baseM -= BASIC_TILE_NUM;
    }

    // max repeat only support 255
    constexpr uint64_t splitSizeLimit =
        static_cast<uint64_t>(SOFTMAX_BASICBLOCK_UNIT) * SOFTMAX_BASICBLOCK_MIN_SIZE;
    // Compare in 64-bit so the product cannot wrap, and stop at one row: for
    // srcK >= splitSizeLimit the unguarded loop would halve baseM all the way to 0,
    // and the caller then divides by it.
    while (baseM > 1 && static_cast<uint64_t>(baseM) * srcK >= splitSizeLimit) {
        baseM = baseM / SOFTMAX_HALF_SIZE;
    }
}

/*!
 * \brief Compute the maximum temporary-buffer size for LogSoftMax.
 *
 * The result is srcM * (elementsPerBlock + srcK + SOFTMAX_BASICBLOCK_UNIT), scaled by
 * SOFTMAX_FLOAT_SIZE, where {srcM, srcK} is the shape collapsed at the last axis.
 *
 * \param srcShape      shape of the source tensor.
 * \param dataTypeSize  element size of the source data type.
 * \param isReuseSource only forwarded to the parameter check.
 * \return the computed size, or 0 when the shape cannot be collapsed to two dimensions
 *         or dataTypeSize is 0.
 */
uint32_t GetLogSoftMaxMaxTmpSize(const ge::Shape srcShape, const uint32_t dataTypeSize, UNUSED const bool isReuseSource)
{
    // Parameter verification, reported under this API's name.
    HighLevelApiCheck::SrcShapeSizeVerifyingParameters<LOG_SOFTMAX_GET_MAX>(srcShape.GetShapeSize(), dataTypeSize);
    HighLevelApiCheck::ShapeLastAxisAlignVerifyingParameters<LOG_SOFTMAX_GET_MAX>(srcShape, dataTypeSize,
        SOFTMAX_DEFAULT_BLK_SIZE);
    HighLevelApiCheck::TypeSizeVerifyingParameters<LOG_SOFTMAX_GET_MAX>(dataTypeSize, SUPPORT_TYPESIZE);
    HighLevelApiCheck::IsReuseSourceVerifyingParameters<LOG_SOFTMAX_GET_MAX>(isReuseSource);

    const std::vector<uint32_t> collapsed = GetLastAxisShapeND(srcShape);
    // Both the row count and the last-axis length are required.
    const bool hasRowsAndCols = collapsed.size() > 1;
    if (dataTypeSize == 0 || !hasRowsAndCols) {
        return 0;
    }

    const uint32_t rowCount = collapsed[0];
    const uint32_t lastAxisLen = collapsed[1];
    const uint32_t elemsPerBlock = SOFTMAX_DEFAULT_BLK_SIZE / dataTypeSize;
    const uint32_t elemsPerRow = elemsPerBlock + lastAxisLen + SOFTMAX_BASICBLOCK_UNIT;
    return rowCount * elemsPerRow * SOFTMAX_FLOAT_SIZE;
}

/*!
 * \brief Compute the minimum temporary-buffer size for LogSoftMax.
 *
 * The result is (elementsPerBlock + srcK + SOFTMAX_BASICBLOCK_UNIT) scaled by
 * SOFTMAX_FLOAT_SIZE, i.e. the requirement for a single row of the collapsed shape.
 *
 * \param srcShape      shape of the source tensor.
 * \param dataTypeSize  element size of the source data type.
 * \param isReuseSource only forwarded to the parameter check.
 * \return the computed size, or 0 when the shape cannot be collapsed to two dimensions
 *         or dataTypeSize is 0.
 */
uint32_t GetLogSoftMaxMinTmpSize(const ge::Shape srcShape, const uint32_t dataTypeSize, UNUSED const bool isReuseSource)
{
    // Parameter verification, reported under this API's name.
    HighLevelApiCheck::SrcShapeSizeVerifyingParameters<LOG_SOFTMAX_GET_MIN>(srcShape.GetShapeSize(), dataTypeSize);
    HighLevelApiCheck::ShapeLastAxisAlignVerifyingParameters<LOG_SOFTMAX_GET_MIN>(srcShape, dataTypeSize,
        SOFTMAX_DEFAULT_BLK_SIZE);
    HighLevelApiCheck::TypeSizeVerifyingParameters<LOG_SOFTMAX_GET_MIN>(dataTypeSize, SUPPORT_TYPESIZE);
    HighLevelApiCheck::IsReuseSourceVerifyingParameters<LOG_SOFTMAX_GET_MIN>(isReuseSource);

    const std::vector<uint32_t> collapsed = GetLastAxisShapeND(srcShape);
    // Both the row count and the last-axis length are required.
    const bool hasRowsAndCols = collapsed.size() > 1;
    if (dataTypeSize == 0 || !hasRowsAndCols) {
        return 0;
    }

    const uint32_t lastAxisLen = collapsed[1];
    const uint32_t elemsPerBlock = SOFTMAX_DEFAULT_BLK_SIZE / dataTypeSize;
    const uint32_t elemsPerRow = elemsPerBlock + lastAxisLen + SOFTMAX_BASICBLOCK_UNIT;
    return elemsPerRow * SOFTMAX_FLOAT_SIZE;
}

/*!
 * \brief Fill the LogSoftMax tiling structure for a given shape and workspace size.
 *
 * The shape is collapsed to {srcM, srcK} at the last axis. The number of rows handled
 * per iteration (baseM) is the largest row count whose
 * (elementsPerBlock + srcK + SOFTMAX_BASICBLOCK_UNIT) elements per row fit into the
 * workspace, capped at srcM, rounded down to a multiple of BASIC_TILE_NUM when it is
 * a partial split larger than BASIC_TILE_NUM, and then passed through
 * AdjustToBasicBlockBaseM.
 *
 * \param srcShape            shape of the source tensor.
 * \param dataTypeSize        element size of the source data type.
 * \param localWorkSpaceSize  size of the available temporary buffer; it is divided by
 *                            SOFTMAX_FLOAT_SIZE to get the element capacity.
 * \param softmaxTiling       [out] tiling structure to fill. It is left untouched when
 *                            the shape cannot be collapsed to two dimensions,
 *                            dataTypeSize is 0, srcM is 0, or the workspace cannot hold
 *                            a single row.
 */
void LogSoftMaxTilingFunc(const ge::Shape srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize,
    optiling::LogSoftMaxTiling& softmaxTiling)
{
    HighLevelApiCheck::SrcShapeSizeVerifyingParameters<LOG_SOFTMAX_TILING>(srcShape.GetShapeSize(), dataTypeSize);
    HighLevelApiCheck::ShapeLastAxisAlignVerifyingParameters<LOG_SOFTMAX_TILING>(srcShape, dataTypeSize,
        SOFTMAX_DEFAULT_BLK_SIZE);
    HighLevelApiCheck::TypeSizeVerifyingParameters<LOG_SOFTMAX_TILING>(dataTypeSize, SUPPORT_TYPESIZE);
    HighLevelApiCheck::LocalWorkSpaceSizeVerifyingParameters<LOG_SOFTMAX_TILING>(localWorkSpaceSize);
    const std::vector<uint32_t> retVec = GetLastAxisShapeND(srcShape);
    if (retVec.size() <= 1 || dataTypeSize == 0) {
        return;
    }
    const uint32_t elementNumPerBlk = SOFTMAX_DEFAULT_BLK_SIZE / dataTypeSize;
    const uint32_t workLocalSize = localWorkSpaceSize / SOFTMAX_FLOAT_SIZE;
    const uint32_t srcK = retVec[1];
    const uint32_t srcM = retVec[0];
    uint32_t baseM = std::min(workLocalSize / (elementNumPerBlk + srcK + SOFTMAX_BASICBLOCK_UNIT), srcM);
    if (baseM == 0) {
        // Either there are no rows or the workspace cannot hold even one row, so no
        // valid split exists. Bail out like the invalid-shape path above instead of
        // dividing by zero when computing rangeM / tailM.
        return;
    }
    if (baseM < srcM && baseM > BASIC_TILE_NUM) {
        baseM = baseM / BASIC_TILE_NUM * BASIC_TILE_NUM;
    }

    AdjustToBasicBlockBaseM(baseM, srcM, srcK);
    if (baseM == 0) {
        // The basic-block adjustment halves baseM and can reach 0 for a very long last
        // axis. The workspace holds at least one row here, so fall back to one row per
        // iteration.
        baseM = 1;
    }

    softmaxTiling.set_srcM(srcM);
    softmaxTiling.set_srcK(srcK);
    softmaxTiling.set_srcSize(srcM * srcK);

    softmaxTiling.set_outMaxM(srcM);             // output dstMax
    softmaxTiling.set_outMaxK(elementNumPerBlk); // output dstMax
    softmaxTiling.set_outMaxSize(srcM * elementNumPerBlk);

    // Per-iteration split: baseM rows of the full last axis.
    softmaxTiling.set_splitM(baseM);
    softmaxTiling.set_splitK(srcK);
    softmaxTiling.set_splitSize(baseM * srcK);

    // Per-iteration reduce result: one block of elements per row.
    softmaxTiling.set_reduceM(baseM);
    softmaxTiling.set_reduceK(elementNumPerBlk);
    softmaxTiling.set_reduceSize(baseM * elementNumPerBlk);

    // Number of full iterations and the leftover rows of the last, partial one.
    const uint32_t range = srcM / baseM;
    const uint32_t tail = srcM % baseM;
    softmaxTiling.set_rangeM(range);
    softmaxTiling.set_tailM(tail);

    softmaxTiling.set_tailSplitSize(tail * srcK);
    softmaxTiling.set_tailReduceSize(tail * elementNumPerBlk);
}
}